import time
import dateparser
import pytz
import json
import ccxt
import pandas as pd
import os
import time
import numpy as np
import requests
import inspect
import shutil
import json
import sys
import click
import argparse
from datetime import timedelta
from datetime import datetime
from datetime import timedelta
from backtrader.utils import logger
from backtrader.utils.tradecalendar import TradingCalendar
from backtrader.utils.datehelper import timestamp2str, str2timestamp
from backtrader.utils.cmd import cli


def filter_cols(x: dict, cols: list) -> dict:
    """Return a copy of ``x`` that keeps only the keys listed in ``cols``.

    Args:
        x: source mapping.
        cols: keys to retain; keys absent from ``x`` are skipped silently.

    Returns:
        A new dict with the retained key/value pairs (insertion order of ``x``
        is preserved).
    """
    return {k: v for k, v in x.items() if k in cols}


# Fallback symbol universe used when the CLI is invoked with --symbol=default:
# a fixed basket of USDT-quoted pairs (compared after .upper() in update_symbol).
default_symbols = [
    "BTC/USDT", "ETH/USDT", "ADA/USDT", "XRP/USDT",
    "DOGE/USDT", "DOT/USDT", "UNI/USDT", "BCH/USDT", "SHIB/USDT",
    "LTC/USDT", "SOL/USDT", "LINK/USDT", "FIL/USDT", "AAVE/USDT",
    "SUSHI/USDT", "CHZ/USDT", "FLOW/USDT", "FTM/USDT", "OKT/USDT",
    "YFI/USDT", "ZEN/USDT", "XLM/USDT", "THETA/USDT",
    "CVC/USDT", "MANA/USDT"
]

# Module-level logger used by the click command functions below.
_logger = logger.getLogger(__name__)


class FeiXiaohaoApi:
    """Thin HTTP client for the feixiaohao.com public ticker API."""

    def __init__(self, url="https://fxhapi.feixiaohao.com", timeout=30):
        """
        Args:
            url: API base URL.
            timeout: per-request timeout in seconds; guards against the HTTP
                call hanging forever (the original code had no timeout).
        """
        self.url = url
        self.timeout = timeout

    def _download_symbols_market(self):
        """Fetch the raw ticker list (up to 3000 symbols) as a list of dicts.

        Fields per entry (per the feixiaohao API):
        "id": coin code (unique primary key),
        "name": coin English name,
        "symbol": coin short symbol,
        "rank": coin rank,
        "logo" / "logo_png": logo URLs (webp / non-webp),
        "price_usd" / "price_btc": latest price in USD / BTC,
        "volume_24h_usd": 24h turnover in USD,
        "market_cap_usd": circulating market cap in USD,
        "available_supply": circulating supply,
        "total_supply": total issued supply,
        "max_supply": max supply (may exceed total_supply, e.g. after burns),
        "percent_change_1h" / "percent_change_24h" / "percent_change_7d": price changes,
        "last_updated": quote update time (10-digit unix timestamp)
        """
        _url = self.url + "/public/v1/ticker?limit=3000"
        r = requests.get(url=_url, timeout=self.timeout)
        # Fail loudly on HTTP errors instead of trying to JSON-parse an error page.
        r.raise_for_status()
        return r.json()

    def fetch_symbols_market(self) -> pd.DataFrame:
        """Download the ticker list and return the selected columns as a DataFrame.

        Returns:
            DataFrame with columns name/baseAsset/rank/volume_24h_usd/
            market_cap_usd/available_supply/total_supply/max_supply/last_updated.
        """
        _cols = ["id", "symbol", "rank", "volume_24h_usd", "market_cap_usd",
                 "available_supply", "total_supply", "max_supply", "last_updated"]

        _data = self._download_symbols_market()
        _data1 = list(map(lambda x: filter_cols(x, _cols), _data))
        df = pd.DataFrame(columns=_cols, data=_data1)
        df = df.rename(columns={"id": "name", "symbol": "baseAsset"})
        return df


class OkexDownloaded:
    """Downloads OHLCV candles from OKEx (via ccxt.okex5) into per-symbol CSVs.

    On-disk layout: <ws>/data/okex/market/<asset_type>/<timeframe>/<SYMBOL>.csv
    plus a cached market snapshot (okex_spot_market.json and a
    okex_<asset_type>_market.csv) per asset type.
    """

    def __init__(self, tc: TradingCalendar, config: dict):
        """
        Args:
            tc: trading calendar that defines the bar grid (timeframe,
                fit/adjust/next/range helpers).
            config: ccxt okex5 config; config["options"]["fetchMarkets"][0]
                selects the asset type (e.g. "spot" / "future" / "swap").
        """
        # The workspace root directory comes from the user's jupyter config file.
        cfg = json.load(open(os.path.join(os.path.expanduser('~'), ".jupyter", "config.json"), 'r'))
        self.ws_dir = cfg["ws"]

        self.okex_ws_dir = os.path.join(self.ws_dir, "data/okex/market")

        self.asset_type = config["options"]["fetchMarkets"][0]
        self.timeframe = tc.timeframe
        self.logger = logger.getLogger(__name__)

        asset_dir = os.path.join(self.okex_ws_dir, self.asset_type)
        # makedirs(exist_ok=True) also creates missing parents and avoids a
        # race when two processes start at once (os.mkdir would raise).
        os.makedirs(asset_dir, exist_ok=True)

        # OKEx caps candle requests at 100 bars per call regardless of asset
        # type (the original if/else set 100 on both branches).
        self.limit = 100

        self.tc = tc
        self.exchange = ccxt.okex5(config)
        self.feixiaohao = FeiXiaohaoApi()

    def update_market(self, active=True):
        """Refresh (at most daily) and persist the exchange market list.

        Args:
            active: when truthy, keep only markets whose "active" flag equals it.

        Returns:
            DataFrame with columns symbol/status/baseAsset/quoteAsset, also
            written to okex_<asset_type>_market.csv.
        """
        markets_json = None
        market_path = os.path.join(self.okex_ws_dir, self.asset_type, "okex_spot_market.json")

        if os.path.exists(market_path):
            data_json = json.load(open(market_path, 'r'))
            markets_json = data_json["data"]
            last_update = datetime.today() - datetime.strptime(data_json["last_update"], self.tc.format_str)
        else:
            # No cache yet: force a refresh by pretending it is 100 days stale.
            last_update = timedelta(days=100)

        if last_update.days > 1:
            # The asset type was already selected through the ccxt options.
            markets_json = self.exchange.fetch_markets()
            data_json = {
                "exchange": "okex",
                "data": markets_json,
                "type": self.asset_type,
                "last_update": datetime.today().strftime(self.tc.format_str)
            }
            json.dump(data_json, open(market_path, 'w'), indent=4)

        if active:
            _data1 = list(filter(lambda x: x["active"] == active, markets_json))
        else:
            _data1 = markets_json

        _cols = ['symbol', 'active', 'base', 'quote']
        _data1 = list(map(lambda x: filter_cols(x, _cols), _data1))

        okex_symbols = pd.DataFrame(columns=_cols, data=_data1)
        okex_symbols.rename(columns={
            "active": "status",
            "base": "baseAsset",
            "quote": "quoteAsset"
        }, inplace=True)

        if self.asset_type == "spot":
            # feixiaohao can return several rows for the same symbol, so the
            # market-cap merge that used to live here is disabled for now.
            okex_symbols.to_csv(os.path.join(self.okex_ws_dir, self.asset_type, "okex_spot_market.csv"), index=False)
            return okex_symbols
        elif self.asset_type == 'future':
            # Fixed: file used to be named binance_future_market.csv, which did
            # not match the okex_<asset>_market.csv name Resample reads.
            okex_symbols.to_csv(os.path.join(self.okex_ws_dir, self.asset_type, "okex_future_market.csv"), index=False)
            return okex_symbols
        elif self.asset_type == "swap":
            # Fixed: same binance_ -> okex_ rename as for futures.
            okex_symbols.to_csv(os.path.join(self.okex_ws_dir, self.asset_type, "okex_swap_market.csv"), index=False)
            return okex_symbols

    def _fetch_k_lines(self,
                       symbol=None,
                       start_date=None,
                       limit=100):
        """Fetch up to ``limit`` candles starting at ``start_date``.

        Notes (from the original author):
        - Binance limits: 1000 for spot, 1500 for futures.
        - With since + limit=2 the API returns two bars including `since`.
        - When data is missing, a window can still come back empty.
        - The bar currently forming is never returned.
        """
        data = []

        _since = start_date
        # Keep probing forward until some data arrives or the calendar ends.
        while len(data) == 0 and _since < self.tc.latest():
            # `before` is nudged 10s earlier so the boundary bar is included.
            before = int((datetime.strptime(_since, self.tc.format_str) - timedelta(seconds=10)).timestamp() * 1000)
            after = int(datetime.strptime(self.tc.next(_since, limit), self.tc.format_str).timestamp() * 1000)
            params = {"before": before, "after": after}
            data = self.exchange.fetch_ohlcv(symbol=symbol, timeframe=self.timeframe, limit=limit, params=params)
            # _since = self.tc.next(_since, limit)
        return data

    def try_fetch_k_lines_start_date(self, symbol, start_date):
        """Walk backwards through daily candles to find the symbol's first bar.

        Args:
            symbol: market symbol, e.g. "BTC/USDT".
            start_date: date string to start searching backwards from.

        Returns:
            The earliest available date as a string.
        """
        _time_frame = "1d"

        limit = 100

        _start_date = datetime.strptime(start_date, self.tc.format_str)
        after = int(_start_date.timestamp() * 1000)
        while True:
            # When data exists, the response includes start_date but excludes
            # the end. NOTE(review): OKT returns nothing under the
            # HistoryCandles option — confirm before re-enabling that option.
            data = self.exchange.fetch_ohlcv(symbol=symbol, timeframe=_time_frame, limit=limit, params={"after": after})
            if len(data) != 0:
                # Page further back from the oldest bar we just received.
                after = data[0][0]
            else:
                return timestamp2str(after / 1000)

    def check_calendar_fit(self, data):
        """Snap millisecond timestamps onto the trading-calendar grid.

        Returns:
            (adjusted_list, original_list) — adjusted values are in ms.
        """
        adjust = list(map(lambda x: x if self.tc.is_fit(x / 1000) else self.tc.adjust(x / 1000) * 1000, data))
        return adjust, data

    def is_not_fit(self, data):
        """Return True if any timestamp (seconds) misses a calendar point."""
        for e in data:
            if not self.tc.is_fit(e):
                return True
        return False

    def update_k_lines(self, symbol=None):
        """Incrementally append candles for one symbol to its CSV file.

        Missing bars are forward/backward filled ("padding") so the file stays
        aligned with the trading calendar.
        """
        if "/" in symbol:
            symbol_name = symbol.replace("/", "_")
        else:
            symbol_name = symbol

        timeframe = self.timeframe
        f_dir = os.path.join(self.okex_ws_dir, self.asset_type, f'{timeframe}')
        os.makedirs(f_dir, exist_ok=True)
        fpath = os.path.join(f_dir, f"{symbol_name}.csv")
        if os.path.exists(fpath):
            t = pd.read_csv(fpath)
        else:
            t = None
        if t is not None and len(t) > 0:
            # Resume from the last stored bar; keep it around for padding.
            last_date = t.iloc[-1]['datetime']
            last_padding_data = t.iloc[-1:]
            first_bar = False
        else:
            first_date = self.try_fetch_k_lines_start_date(
                symbol=symbol,
                start_date=self.tc.last()  # OKEx is searched backwards from the end
            )
            if not self.tc.is_fit(str2timestamp(first_date)):
                raise Exception(f"try get the first date is not calendar point.")
            last_date = first_date
            last_padding_data = None
            first_bar = True

        # Each fetch starts from the current last bar so gaps can be padded.

        # Times here are local-time strings in and local-time strings out.
        _limit = self.limit

        # Every loop iteration requests up to _limit bars.
        _start_date = self.tc.next(last_date) if not first_bar else last_date

        while True:

            # Returns bars from _start_date, excluding the still-forming bar.
            data = self._fetch_k_lines(
                symbol=symbol,
                start_date=_start_date,
                limit=_limit
            )

            if len(data) == 0:
                break

            # Raw (pre-fit) start/end timestamps of this batch.
            _data_start_ = timestamp2str(data[0][0] / 1000)
            _data_end_ = timestamp2str(data[-1][0] / 1000)

            # Trim anything past the calendar end to avoid boundary bars.
            if _data_end_ > self.tc.last():
                data = list(filter(lambda x: x[0] <= (str2timestamp(self.tc.last()) * 1000), data))
                if len(data) == 0:
                    break
                _data_end_ = timestamp2str(data[-1][0] / 1000)

            if _start_date != _data_start_:
                if not first_bar:
                    self.logger.warning("### [%s] miss data from [%s] - [%s]", symbol, _start_date, _data_start_)
                else:
                    # First batch ever: adopt the first returned bar as start.
                    _t = str2timestamp(_data_start_)
                    _start_date = _data_start_ if self.tc.is_fit(_t) else timestamp2str(self.tc.adjust(_t))
                    first_bar = False

            fit_data = list(map(lambda x: x[0] / 1000, data))

            if self.is_not_fit(fit_data):
                _s = timestamp2str(fit_data[0])
                _e = timestamp2str(fit_data[-1])
                self.logger.info("### [%s]-[%s / %s] is not fit trading calendar point, need fit.", symbol, _s, _e)
                fit_data = list(map(lambda x: self.tc.adjust(x), fit_data))

            # Calendar-fitted start/end of this batch.
            _fit_data_start_ = timestamp2str(fit_data[0])
            _fit_data_end_ = timestamp2str(fit_data[-1])

            self.logger.info("fetch [%s]-[%s / %s]", symbol, _data_start_, _data_end_)

            cols = ['datetime', 'open', 'high', 'low', 'close', 'vol']
            tmp = pd.DataFrame(columns=cols, data=data, dtype="float")
            tmp['timestamp'] = tmp['datetime'].astype(int)
            tmp['datetime'] = list(map(lambda x: timestamp2str(x), fit_data))
            tmp = tmp.drop_duplicates("datetime")

            # Use _start_date (not _fit_data_start_): the requested start may
            # itself be missing and then needs padding. The expected index runs
            # from the requested start to the fitted end of received data.
            base_index = self.tc.range_list(_start_date, _fit_data_end_)
            data_index = tmp["datetime"].tolist()
            if data_index != base_index:
                if last_padding_data is not None:
                    # Prepend the previous bar so its close can seed the ffill.
                    padding_datetime = last_padding_data["datetime"].iloc[0]
                    base_index = [padding_datetime] + base_index
                    tmp = pd.concat([last_padding_data, tmp])
                    self.logger.warning("### [%s] miss data, so padding", symbol)

                base = pd.DataFrame(columns=["datetime"], data=base_index)
                tmp = base.merge(tmp, on="datetime", how="left")
                tmp["close"] = tmp["close"].ffill()
                tmp = tmp.fillna(value={"vol": 0})

                # Row-wise bfill copies close into o/h/l; timestamp stays NaN.
                tmp = tmp.fillna(axis=1, method="bfill")

                # Fill the remaining NaN timestamps from the datetime strings.
                row_index = tmp[tmp["timestamp"].isna()].index
                column_index = ["timestamp"]

                fill_timestamp = tmp[tmp["timestamp"].isna()]
                tmp.loc[row_index, column_index] = pd.to_datetime(fill_timestamp["datetime"]).apply(
                    lambda x: int(x.timestamp() * 1000)
                )

                self.logger.info("Padding %s, %s", symbol, len(tmp))

                if last_padding_data is not None:
                    # Drop the seed row again — it is already on disk.
                    tmp = tmp.iloc[1:]

            if not os.path.exists(fpath):
                tmp.to_csv(fpath, index=False, mode='a')
            else:
                if len(tmp):
                    tmp.to_csv(fpath, index=False, mode='a', header=False)

            # The next request deliberately restarts one bar after the fitted
            # end; the last written row is kept as the next padding seed.
            _start_date = self.tc.next(_fit_data_end_)
            last_padding_data = tmp.iloc[-1:]

    def update_symbol(self, symbol):
        """Resolve the symbol selector and download candles for each match.

        Args:
            symbol: "ALL" (every active market), "DEFAULT" (the built-in
                basket), or a comma-separated list of pairs.
        """
        symbols = []
        if symbol == "all".upper():
            df = self.update_market()
            if len(df):
                symbols = (df["baseAsset"] + "/" + df["quoteAsset"]).tolist()

            if self.asset_type == 'spot':
                symbols = list(filter(lambda x: x.endswith("USDT"), symbols))
        elif symbol == "default".upper():
            symbols = default_symbols
        else:
            symbols = symbol.split(",")

        filter_symbols = symbols

        for idx, symbol in enumerate(filter_symbols):
            try:
                self.logger.info("Download [%s], [%s], [%s] - [%s/%s]",
                                 self.asset_type, self.timeframe, symbol, idx + 1, len(filter_symbols))
                self.update_k_lines(symbol=symbol)
            except Exception as e:
                # Best-effort: log the failure and continue with the next symbol.
                self.logger.info(f"Error:####")
                self.logger.exception(e)
                continue


class Resample:
    """Resamples locally stored 1m OKEx candles into a coarser timeframe.

    Reads <asset>/1m/<SYMBOL>.csv and appends the aggregated bars to
    <asset>/<timeframe>/<SYMBOL>.csv.
    """

    def __init__(self, tc: TradingCalendar, time_frame='1m', asset_type="spot"):
        """
        Args:
            tc: trading calendar for the TARGET timeframe.
            time_frame: target bar size, e.g. "5m".
            asset_type: "spot" / "future" / "swap".
        """
        cfg = json.load(open(os.path.join(os.path.expanduser('~'), ".jupyter", "config.json"), 'r'))
        ws_dir = cfg["ws"]
        self.okex_ws_dir = os.path.join(ws_dir, "data/okex/market")
        self.timeframe = time_frame
        # pandas resample rule: "5m" -> "5T" (T = minutes).
        self.timeframe2 = self.timeframe.replace("m", "T")
        self.asset_type = asset_type
        self.tc = tc
        self.logger = logger.getLogger(__name__)

    def update_resample_k_lines(self, symbol):
        """Append newly available resampled bars for one symbol.

        Only the 1m rows between the last already-resampled bar and the last
        COMPLETE target bar are aggregated, so re-running is incremental.
        """
        symbol = symbol.replace("/", "_")
        dest_path = os.path.join(self.okex_ws_dir, self.asset_type, self.timeframe, f"{symbol}.csv")

        source_path = os.path.join(self.okex_ws_dir, self.asset_type, "1m", f"{symbol}.csv")

        if not os.path.exists(source_path):
            self.logger.warning("the source %s 1m ohlc not found.", symbol)
            return

        if os.path.exists(dest_path):
            df = pd.read_csv(dest_path)
            last_date = df.iloc[-1]["datetime"]
            resample_start_tf = self.tc.next(last_date)
        else:
            resample_start_tf = None

        source_df = pd.read_csv(source_path)
        if not len(source_df):
            self.logger.warning("the source %s 1m ohlc data is empty.", symbol)
            return

        # End at the last complete target bar: adjust snaps the newest 1m
        # timestamp down to the target grid, and rows at/after it are dropped.
        source_last_date = source_df.iloc[-1]["datetime"]
        resample_end_tf = timestamp2str(self.tc.adjust(str2timestamp(source_last_date)))

        if resample_start_tf:
            source_df = source_df[source_df["datetime"] >= resample_start_tf]

        if resample_end_tf:
            source_df = source_df[source_df["datetime"] < resample_end_tf]

        if not len(source_df):
            self.logger.warning("the source %s 1m ohlc data not match the %s resample bar.", symbol, self.timeframe)
            return

        # Standard OHLCV aggregation; timestamp keeps the bar's opening time.
        func = {
            "open": "first",
            "high": "max",
            "low": "min",
            "close": "last",
            "vol": "sum",
            "timestamp": "first"
        }

        source_df["datetime"] = pd.to_datetime(source_df["datetime"])
        source_df = source_df.set_index("datetime")

        # NOTE(review): daily/weekly bars might need offset="8h" for exchanges
        # whose day starts at 08:00 local — confirm before enabling d/w frames.
        resample_df = source_df.resample(self.timeframe2).agg(func)

        resample_df = resample_df.reset_index()

        resample_df["vol"] = resample_df["vol"].round(4)

        if len(resample_df) == 0:
            return

        if not os.path.exists(dest_path):
            # First write: create the timeframe directory and emit the header.
            dir_path = os.path.dirname(dest_path)
            os.makedirs(dir_path, exist_ok=True)
            resample_df.to_csv(dest_path, header=True, mode="a", index=False)
        else:
            resample_df.to_csv(dest_path, header=False, mode="a", index=False)

    def update_symbol(self, symbol):
        """Resolve the symbol selector and resample each USDT-quoted match.

        Args:
            symbol: "ALL" (from the stored market csv), "DEFAULT", or a
                comma-separated list of pairs.
        """
        symbols = []
        if symbol == "all".upper():
            fpath = os.path.join(self.okex_ws_dir, self.asset_type, f"okex_{self.asset_type}_market.csv")
            df = pd.read_csv(fpath)
            if len(df):
                symbols = (df["base"] + "/" + df["quote"]).tolist()
        elif symbol == "default".upper():
            symbols = default_symbols
        else:
            symbols = symbol.split(",")

        # All candidates are filtered here once; the redundant per-iteration
        # endswith check the original carried has been removed.
        filter_symbols = list(filter(lambda x: x.endswith("USDT"), symbols))
        for idx, symbol in enumerate(filter_symbols):
            try:
                self.logger.info("Resample [%s], [%s], [%s] - [%s/%s]",
                                 self.asset_type, self.timeframe, symbol, idx + 1, len(filter_symbols))
                self.update_resample_k_lines(symbol=symbol)
            except Exception as e:
                # Best-effort: log the failure and continue with the next symbol.
                self.logger.info(f"Resample Error:####")
                self.logger.exception(e)
                continue


@cli.command()
@click.option("--symbol", default="default")
@click.option('--asset', type=click.STRING, default="spot")
@click.option('--time_frame', type=click.STRING, default="5m")
def resample(symbol, asset, time_frame):
    """Resample stored 1m candles into the requested timeframe(s).

    Args:
        symbol: "default", "all", or a comma-separated list of pairs.
        asset: comma-separated asset types, e.g. "spot,swap".
        time_frame: comma-separated target frames, e.g. "5m,15m" ("1m" is
            rejected — it is the source resolution).
    """
    if time_frame == "1m":
        _logger.info("Resample don't support %s", time_frame)
        return
    assets = asset.split(",")
    timeframes = time_frame.split(",")

    format_str = "%Y-%m-%d %H:%M:%S"
    today = datetime.today()

    for _asset in assets:
        for _time_frame in timeframes:
            if _time_frame == "1m":
                # Fixed: log the offending per-item frame (_time_frame), not
                # the raw comma-joined option string (time_frame).
                _logger.info("Resample don't support %s", _time_frame)
                continue
            trading_calendar = TradingCalendar(
                start_date_str=init_start_date(_asset),
                end_date_str=today.strftime(format_str),
                time_frame=_time_frame
            )
            _resample = Resample(tc=trading_calendar, time_frame=_time_frame, asset_type=_asset)
            _resample.update_symbol(symbol=symbol.upper())


@cli.command()
@click.option("--symbol", default="default")
@click.option('--asset', default="spot")
@click.option('--proxy', type=click.STRING, default=None)
def download(symbol, asset, proxy, time_frame="1d"):
    """Download OHLCV history from OKEx for one or more asset types.

    Args:
        symbol: "default", "all", or a comma-separated list of pairs.
        asset: comma-separated asset types.
        proxy: "http" selects the built-in local proxy, otherwise an explicit
            host:port; None disables proxying.
        time_frame: bar size (not exposed as a CLI option; defaults to "1d").
    """
    # "http" is shorthand for the developer's local proxy endpoint.
    if proxy is not None and proxy == "http":
        proxy = "127.0.0.1:41091"

    # Loop-invariant values hoisted out of the per-asset loop.
    format_str = "%Y-%m-%d %H:%M:%S"
    today = datetime.today()

    assets = asset.split(",")
    for _asset in assets:
        # NOTE(review): "futures"/"option" pass this check but
        # init_start_date() only supports "spot"/"swap" and would raise later
        # — confirm whether the extra types are intended.
        if _asset not in ["spot", "swap", "futures", "option"]:
            raise Exception(f"{_asset} is not valid.")

        config = {
            'enableRateLimit': True,
            'rateLimit': 500,
            'timeout': 10000,
            'options': {
                # 'fetchOHLCV': {'type': 'HistoryCandles'} is deliberately
                # disabled — some symbols (e.g. OKT) return no data under it.
                'fetchMarkets': [_asset]
            }
        }

        if proxy is not None:
            config["proxies"] = {
                'http': proxy,
                'https': proxy
            }

        trading_calendar = TradingCalendar(
            start_date_str=init_start_date(_asset),
            end_date_str=today.strftime(format_str),
            time_frame=time_frame,
            offset=1
        )
        downloaded = OkexDownloaded(tc=trading_calendar, config=config)
        downloaded.update_symbol(symbol=symbol.upper())


def init_start_date(asset):
    """Return the download start date (history origin) for an asset type.

    Args:
        asset: "spot" or "swap".

    Returns:
        Date string in "%Y-%m-%d %H:%M:%S" format.

    Raises:
        Exception: for any unsupported asset type.
    """
    # Mapping replaces the original if/elif chain (both branches held the
    # same date); extend here when new asset types gain support.
    start_dates = {
        "spot": "2019-08-01 00:00:00",
        "swap": "2019-08-01 00:00:00",
    }
    if asset not in start_dates:
        raise Exception(f"{asset} not support")
    return start_dates[asset]


# CLI entry point grouping the sub-commands: `download` (fetch from OKEx)
# and `resample` (aggregate stored 1m bars).
@click.group()
def okexdata():
    # Group body is intentionally empty; click dispatches to sub-commands.
    pass


okexdata.add_command(download)
okexdata.add_command(resample)

if __name__ == "__main__":
    okexdata()
